# This file reads in atmospheric fluxes from an HDF5 file in the nuSQuIDS format and propagates them through the Earth using nuSQuIDS. 
# It saves the propagated fluxes to a .npz file so they can be interpolated later, and outputs a plot for verification. 
# adapted by A. Wen from the nuSQuIDS demo file.

# usage: python atm_flux_propagation_fromFile.py atmospheric.hdf5 propagated
# the second argument 'propagated' indicates that the input file already contains Earth-propagated fluxes, so they are not propagated again here; any other value triggers the propagation.

# imports 
################################################################################################################
import os
import sys

import matplotlib.pyplot as plt
import numpy as np

# nuSQuIDS: https://github.com/arguelles/nuSQuIDS
import nuSQuIDS as nsq


# choose file to input
################################################################################################################
# Command line: <script> <flux_file.hdf5> [propagated]
#   argv[1]: path of the nuSQuIDS HDF5 flux file to read.
#   argv[2]: pass 'propagated' when the file already holds earth-propagated
#            fluxes; any other value (or leaving it out) triggers propagation.
if len(sys.argv) < 2:
    # Stop with a usage hint instead of a bare IndexError traceback.
    sys.exit('usage: python '+sys.argv[0]+' atmospheric.hdf5 [propagated]')

file_name = sys.argv[1]
#example: 'atmospheric_0_0.000000_0.000000_0.000000_0.000000_0.000000_0.000000.hdf5'

# The status word is optional; an empty status means "not yet propagated".
propagation_status = sys.argv[2] if len(sys.argv) > 2 else ''

print('Making fluxes with file '+file_name)
if propagation_status == 'propagated':
    print('fluxes already propagated.')
else:
    print('propagating fluxes...')

# nusquids setup
################################################################################################################
# Unit constants from nuSQuIDS; units.GeV scales the energy bounds below.
units = nsq.Const()

# Options handed to the nuSQUIDSAtm constructor at the end of this section.
interactions = True
neutrino_flavors = 3

# Energy grid: E_nodes points between E_min and E_max, built by nsq.logspace.
E_nodes = 150
E_min, E_max = 50.0*units.GeV, 5000000*units.GeV
energy_nodes = nsq.logspace(E_min, E_max, E_nodes)

# cos(theta) grid: 100 points between cth_min and cth_max, built by nsq.linspace.
cth_min, cth_max = -1.0, 0.1
cth_nodes = nsq.linspace(cth_min, cth_max, 100)

# Atmospheric nuSQuIDS object on the (cos(theta), energy) grid defined above,
# created with nsq.NeutrinoType.both.
nsq_atm = nsq.nuSQUIDSAtm(
    cth_nodes,
    energy_nodes,
    neutrino_flavors,
    nsq.NeutrinoType.both,
    interactions,
)

# set the state from the file and propagate
################################################################################################################
# Load the flux state stored in the input file into the nuSQuIDS object.
nsq_atm.ReadStateHDF5(file_name)

# Integrator settings: the same tolerance is used for the relative and the
# absolute error, and a progress bar is printed on the terminal.
integration_tolerance = 1.0e-15
nsq_atm.Set_rel_error(integration_tolerance)
nsq_atm.Set_abs_error(integration_tolerance)
nsq_atm.Set_ProgressBar(True)

# Only evolve the state when the input was not flagged as already propagated.
already_propagated = (propagation_status == 'propagated')
if not already_propagated:
    nsq_atm.EvolveState()

# Energy and cos(theta) ranges reported by the nuSQuIDS object.
erange, cthrange = nsq_atm.GetERange(), nsq_atm.GetCosthRange()

# make a plot to show the effect of the propagation
################################################################################################################
# Flavor index 1 flux versus energy at a fixed cos(theta) = -0.5.
# The last EvalFlavor argument is 0 for the curve labelled nu_mu and 1 for the
# curve labelled nu_mu-bar in the legend below.
plot_costh = -0.5
phi_mu = [nsq_atm.EvalFlavor(1,plot_costh,EE,0) for EE in erange]
phi_mubar = [nsq_atm.EvalFlavor(1,plot_costh,EE,1) for EE in erange]

# The x axis is labelled in GeV, so the energies are divided by 1e9
# (presumably eV -> GeV, matching the units.GeV scaling used for the grid).
erange_GeV = erange/1e9

plt.figure(figsize = (6,6))

plt.plot(erange_GeV, phi_mu, lw = 2.5, color = "cornflowerblue", label = r"$\nu_\mu$ flux at detector")
plt.plot(erange_GeV, phi_mubar, lw = 2.5, color = "darkblue", label = r"$\overline{\nu}_\mu$ flux at detector")

plt.loglog()

plt.xlim(erange_GeV[0],erange_GeV[-1])
plt.xlabel(r"$E_\nu \: [{\rm GeV}]$")
plt.ylabel(r"$\phi^{atm}_\nu (E_\nu, \cos(\theta)=-0.5) \: [1/({\rm GeV}\:{\rm s}\:{\rm cm}^2\:{\rm sr})]$")

plt.grid()

plt.legend()

# Remove the '.hdf5' extension as a suffix. str.rstrip('.hdf5') would strip any
# trailing run of the characters '.', 'h', 'd', 'f', '5' and could therefore
# eat into the file stem (e.g. 'run5.hdf5' -> 'run').
hdf5_suffix = '.hdf5'
if file_name.endswith(hdf5_suffix):
    plot_stem = file_name[:-len(hdf5_suffix)]
else:
    plot_stem = file_name

plt.savefig(plot_stem+'.pdf')

# save the propagated fluxes as a function of energy and theta, they can be interpolated and eval'ed later: 
# (we care about muon flavor only)
################################################################################################################################
# Grids on which the final state is sampled, as reported by the nuSQuIDS object.
energy_nodes = nsq_atm.GetERange()
cth_nodes = nsq_atm.GetCosthRange()

# Stored grids: energies are divided by 1e9 (presumably eV -> GeV, matching the
# units.GeV scaling used when the grid was defined); cos(theta) is stored as is.
Enodes = np.asarray(energy_nodes) / 1e9
Cthnodes = np.asarray(cth_nodes)

# flux[rho, i_cth, i_E] for flavor index 1, where rho is the last EvalFlavor
# argument (0 and 1; the verification plot labels these nu_mu and nu_mu-bar).
AtmMuFinalFlux = np.zeros((2,len(cth_nodes),len(energy_nodes)))
for ic,cth in enumerate(cth_nodes):
    for ie,E in enumerate(energy_nodes):
        AtmMuFinalFlux[0,ic,ie] = nsq_atm.EvalFlavor(1,cth,E,0)
        AtmMuFinalFlux[1,ic,ie] = nsq_atm.EvalFlavor(1,cth,E,1)

# Build the output name once. The prefix is applied to the file's base name so
# that an input given with a directory part still yields a valid path (next to
# the input file), and the '.hdf5' extension is removed as a suffix:
# str.rstrip('.hdf5') strips a character set and can truncate the stem.
input_dir, input_base = os.path.split(file_name)
if input_base.endswith('.hdf5'):
    input_base = input_base[:-len('.hdf5')]
npz_path = os.path.join(input_dir, 'AtmMuFinalFlux'+input_base+'.npz')

np.savez(npz_path, cth_nodes=Cthnodes, energy_nodes=Enodes, flux=AtmMuFinalFlux)
print('Propagated fluxes saved to '+npz_path)